package fusion

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"strings"
	"sync"
	"sync/atomic"
	"time"
)

// Status codes returned by packet handlers and interpreted by
// Session.checkHandleStatus. Success, Unhandle, Warning and Error keep the
// session alive (all but Success are logged); Kill, or any value not listed
// here, makes checkHandleStatus report failure, and the session is then shut
// down.
const (
	SessionHandleSuccess  = 0
	SessionHandleUnhandle = 1
	SessionHandleWarning  = 2
	SessionHandleError    = 3
	SessionHandleKill     = 4
)

// Session wraps one network connection. It owns an outgoing send queue that a
// dedicated goroutine (startSendCoroutine) drains to the connection, a receive
// loop (RunRecvCoroutine), and reassembly state for fragmented large packets.
type Session struct {
	// Underlying connection: read by RecvPacket, written by the send goroutine.
	conn        net.Conn
	// Outgoing data queued by the Send* methods and drained by the send goroutine.
	buffer      *sendBuffer
	// Capacity-1 wake-up signal for the send goroutine (see NotifyNonBlock calls).
	sendSig     chan bool
	// Serial number of the last outgoing fragmented packet; incremented with
	// atomic.AddUint32 and reset to 0 on (re)initialisation.
	fragPckSN   uint32
	// Incoming fragmented packets being reassembled, keyed by serial number.
	fragPcks    map[uint32]*NetPacket
	// Owning service; supplies the Tasks queue and RPCMgr.
	serviceBase *ServiceBase
	// Set by Shutdown. NOTE(review): read by the receive loop, by queued tasks
	// and by the send goroutine while Shutdown may write it from another
	// goroutine, with no synchronization — a data race; consider atomic.Bool.
	isShutdown  bool
	// Generation counter bumped by InitAndStartSendCoroutine; tasks queued by
	// RunRecvCoroutine compare against it and are dropped if it changed.
	instAge     int
}

// InitAndStartSendCoroutine (re)initialises the session around conn, owned by
// service s, and launches its send goroutine.
func (obj *Session) InitAndStartSendCoroutine(s *ServiceBase, conn net.Conn) {
	// Bump the generation first: tasks queued by an earlier incarnation of this
	// session compare against instAge and become no-ops.
	obj.instAge++
	obj.serviceBase = s
	obj.conn = conn
	obj.isShutdown = false

	// Fresh fragmentation state in both directions.
	obj.fragPckSN = 0
	obj.fragPcks = make(map[uint32]*NetPacket)

	// Fresh send queue and wake-up signal, then start draining it.
	obj.sendSig = make(chan bool, 1)
	obj.buffer = new(sendBuffer)
	obj.startSendCoroutine()
}

// RunRecvCoroutine reads packets from the connection until the session is shut
// down or a read fails. A read failure is logged; the session is not shut down
// here.
//
// Packets whose opcode equals rpc_resp_opcode are RPC responses and go straight
// to RPCMgr.Reply on this goroutine. Every other packet is wrapped in a task
// sent to serviceBase.Tasks. That task runs packet_callback and shuts the
// session down gracefully if checkHandleStatus rejects the returned status.
// Tasks are skipped if, by the time they run, the session has been shut down or
// re-initialised (instAge changed).
func (obj *Session) RunRecvCoroutine(rpc_resp_opcode int, packet_callback func(*NetPacket) int) {
	age := obj.instAge
	for !obj.isShutdown {
		pck, err := obj.RecvPacket()
		if err != nil {
			log.Printf("read game `%s` packet failed, %s.\n", obj, err)
			return
		}
		if pck.Opcode == rpc_resp_opcode {
			obj.serviceBase.RPCMgr.Reply(&pck.NetBuffer, ReadRPCRespMetaInfo(pck))
			continue
		}
		obj.serviceBase.Tasks <- func() {
			if obj.isShutdown || obj.instAge != age {
				return
			}
			if !obj.checkHandleStatus(packet_callback(pck), pck.Opcode) {
				obj.Shutdown(false)
			}
		}
	}
}

// Shutdown marks the session as shut down and wakes the send goroutine. Calls
// after the first one do nothing.
//
// With isKill the connection is closed immediately. Otherwise the connection is
// left open so the send goroutine can flush what is still queued, with writes
// bounded by a 30-second deadline.
func (obj *Session) Shutdown(isKill bool) {
	if obj.isShutdown {
		return
	}
	obj.isShutdown = true
	switch {
	case isKill:
		obj.conn.Close()
	default:
		obj.conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
	}
	NotifyNonBlock(obj.sendSig)
}

// ShutdownAndStopSendCoroutine calls Shutdown(isKill) and then blocks until
// the send goroutine has removed the connection from AllClosers on its way out.
func (obj *Session) ShutdownAndStopSendCoroutine(isKill bool) {
	obj.Shutdown(isKill)
	// Poll until the send goroutine is gone.
	obj.waitStopSendCoroutine()
}

// String implements fmt.Stringer by reporting the peer's network address.
func (obj *Session) String() string {
	remote := obj.conn.RemoteAddr()
	return remote.String()
}

// GetIP returns the host part of the peer's address, i.e. the remote address
// with its port removed ("10.0.0.1" for "10.0.0.1:5000", "::1" for
// "[::1]:5000").
//
// Addresses that net.SplitHostPort cannot parse fall back to cutting at the
// last ':' (the previous behaviour). Addresses with no ':' at all, such as the
// "pipe" address of net.Pipe or a unix socket path, are returned unchanged
// instead of triggering a slice-bounds panic.
func (obj *Session) GetIP() string {
	addr := obj.conn.RemoteAddr().String()
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	if i := strings.LastIndex(addr, ":"); i >= 0 {
		return addr[:i]
	}
	return addr
}

// RecvPacket blocks until one complete packet has been read from the
// connection and returns it.
//
// Packets with opcode OPCODE_LARGE_PACKET are fragments. They are handed to
// doRecvFragmentPacket, and reading continues until a fragment completes a
// reassembled packet, which is then returned.
//
// Errors: any error from reading the header or body is returned as-is. If the
// stream ends cleanly before the declared body length has arrived,
// io.ErrUnexpectedEOF is returned.
func (obj *Session) RecvPacket() (*NetPacket, error) {
	for {
		opcode, length, err := ReadNetPacketHeaderFromReader(obj.conn)
		if err != nil {
			return nil, err
		}
		pck := NetPacket{Opcode: opcode}
		n, err := pck.buffer.ReadFrom(io.LimitReader(obj.conn, int64(length)))
		if err != nil {
			// Keep the real I/O error (reset, timeout, ...) instead of
			// reporting it as a truncated packet.
			return nil, err
		}
		if n != int64(length) {
			return nil, io.ErrUnexpectedEOF
		}
		if pck.Opcode != OPCODE_LARGE_PACKET {
			return &pck, nil
		}
		if whole := obj.doRecvFragmentPacket(&pck); whole != nil {
			return whole, nil
		}
	}
}

// SendPacket queues pck for sending and wakes the send goroutine.
//
// Payloads larger than PacketMaxBufferSize are split into fragments by
// sendOverflowPacket, and any error it reports is returned; previously that
// error was silently dropped.
func (obj *Session) SendPacket(pck *NetPacket) error {
	var err error
	if pck.GetReadableSize() <= PacketMaxBufferSize {
		obj.buffer.writePacket(pck)
	} else {
		err = obj.sendOverflowPacket(pck)
	}
	// Wake the sender even on error, in case some fragments were already queued.
	NotifyNonBlock(obj.sendSig)
	return err
}

// SendPacketWithData queues pck with data appended to its payload, then wakes
// the send goroutine.
//
// If the combined size exceeds PacketMaxBufferSize the message is fragmented
// by sendOverflowPacketWithData, and any error it reports is returned;
// previously that error was silently dropped.
func (obj *Session) SendPacketWithData(pck *NetPacket, data []byte) error {
	var err error
	if pck.GetReadableSize()+len(data) <= PacketMaxBufferSize {
		obj.buffer.writePacketWithData(pck, data)
	} else {
		err = obj.sendOverflowPacketWithData(pck, data)
	}
	// Wake the sender even on error, in case some fragments were already queued.
	NotifyNonBlock(obj.sendSig)
	return err
}

// SendPacketWithPkt queues pck followed by the embedded packet pkt, then wakes
// the send goroutine. The size check budgets PacketHeaderSize bytes for pkt's
// own header on top of its payload.
//
// If the total exceeds PacketMaxBufferSize the message is fragmented by
// sendOverflowPacketWithPkt, and any error it reports is returned; previously
// that error was silently dropped.
func (obj *Session) SendPacketWithPkt(pck *NetPacket, pkt *NetPacket) error {
	var err error
	var extraPktLen = PacketHeaderSize + pkt.GetReadableSize()
	if pck.GetReadableSize()+extraPktLen <= PacketMaxBufferSize {
		obj.buffer.writePacketWithPkt(pck, pkt)
	} else {
		err = obj.sendOverflowPacketWithPkt(pck, pkt)
	}
	// Wake the sender even on error, in case some fragments were already queued.
	NotifyNonBlock(obj.sendSig)
	return err
}

// RPCInvoke sends pck as an RPC request whose replies are delivered to cb.
//
// It registers cb with RPCMgr (passing timeout through) and appends the
// returned serial number to pck's payload before sending. If the send fails,
// the pending request is cancelled with RPCErrorFailed and the error returned.
func (obj *Session) RPCInvoke(pck *NetPacket, cb func(*NetBuffer, int32, bool), timeout int64) error {
	reqSN := obj.serviceBase.RPCMgr.AddReq(cb, timeout)

	var trailer NetBuffer
	trailer.Write(reqSN)

	sendErr := obj.SendPacketWithData(pck, trailer.GetReadableBytes())
	if sendErr != nil {
		obj.serviceBase.RPCMgr.CancelReq(reqSN, RPCErrorFailed)
	}
	return sendErr
}

// RPCBlockInvoke sends pck as an RPC request and returns a context that
// collects the replies.
//
// Each reply callback appends an RPCCBArgs to ctx.args while holding ctx.mutex
// and then signals ctx.cond. Presumably the caller waits on ctx.cond to consume
// them — confirm against RPCBlockContext's methods.
func (obj *Session) RPCBlockInvoke(pck *NetPacket, timeout int64) (*RPCBlockContext, error) {
	ctx := new(RPCBlockContext)
	ctx.cond = sync.NewCond(&ctx.mutex)

	onReply := func(resp *NetBuffer, code int32, eof bool) {
		ctx.mutex.Lock()
		ctx.args.PushBack(&RPCCBArgs{resp, code, eof})
		ctx.mutex.Unlock()
		ctx.cond.Signal()
	}

	if err := obj.RPCInvoke(pck, onReply, timeout); err != nil {
		return nil, err
	}
	return ctx, nil
}

// RPCReply sends pck as a reply to the request described by info.
//
// The request serial number, the error code err and the eof flag are
// serialized after pck's payload. eof presumably marks the last reply of a
// (possibly streamed) response — TODO confirm against RPCMgr.Reply.
func (obj *Session) RPCReply(pck *NetPacket, info *RPCReqMetaInfo, err int32, eof bool) error {
	var trailer NetBuffer
	trailer.Write(info.sn, err, eof)
	return obj.SendPacketWithData(pck, trailer.GetReadableBytes())
}

// RPCReplySimple sends pck as a successful reply (RPCErrorNone) with the eof
// flag set, to the request described by info.
func (obj *Session) RPCReplySimple(pck *NetPacket, info *RPCReqMetaInfo) error {
	const isLast = true
	return obj.RPCReply(pck, info, RPCErrorNone, isLast)
}

// doRecvFragmentPacket absorbs one OPCODE_LARGE_PACKET fragment.
//
// Each fragment body starts with a uint32 serial number identifying the large
// packet it belongs to. The first fragment for a serial number is stored in
// obj.fragPcks, and the remaining bytes of later fragments are appended to it.
//
// A fragment whose body (serial number included) is shorter than
// PacketMaxBufferSize ends the sequence. The entry is then removed, and the
// large-packet header (LargePacketHeaderSize bytes at the front of the
// reassembled data) is consumed to recover the real opcode and total size. The
// reassembled packet is returned. Otherwise nil is returned and the caller
// keeps reading.
//
// Panics if the reassembled payload size differs from the size in the header.
// NOTE(review): fragments come from the peer, so a malformed stream panics the
// receiving goroutine, and fragPcks grows without bound if a sequence is never
// terminated — consider returning an error and capping the map.
func (obj *Session) doRecvFragmentPacket(pck *NetPacket) *NetPacket {
	// Measured before the serial number is consumed: the sender fills whole
	// fragment bodies (prefix included) up to PacketMaxBufferSize.
	var fragPckSize = pck.GetReadableSize()
	var fragPckSN uint32
	pck.Read(&fragPckSN)
	var fragPck, isOK = obj.fragPcks[fragPckSN]
	if isOK {
		fragPck.AppendBytes(pck.GetReadableBytes())
	} else {
		// First fragment: it becomes the accumulator for this serial number.
		fragPck, obj.fragPcks[fragPckSN] = pck, pck
	}
	if fragPckSize < PacketMaxBufferSize {
		// Short fragment => last one; finish reassembly.
		delete(obj.fragPcks, fragPckSN)
		opcode, totalPckSize :=
			ReadLargeNetPacketHeader(fragPck.Next(LargePacketHeaderSize))
		if totalPckSize != fragPck.GetReadableSize() {
			panic(fmt.Errorf("Packet size isn't equal(%d,%d)",
				totalPckSize, fragPck.GetReadableSize()))
		}
		fragPck.Opcode = opcode
		return fragPck
	}
	return nil
}

// sendOverflowPacket queues pck's payload as a fragmented large packet.
func (obj *Session) sendOverflowPacket(pck *NetPacket) error {
	payload := pck.GetReadableBytes()
	return obj.doSendFragmentPacket(pck.Opcode, [][]byte{payload})
}

// sendOverflowPacketWithData queues pck's payload followed by data as one
// fragmented large packet carrying pck's opcode.
func (obj *Session) sendOverflowPacketWithData(pck *NetPacket, data []byte) error {
	parts := [][]byte{pck.GetReadableBytes(), data}
	return obj.doSendFragmentPacket(pck.Opcode, parts)
}

// sendOverflowPacketWithPkt queues pck's payload, then a packet header for pkt,
// then pkt's payload, as one fragmented large packet carrying pck's opcode.
func (obj *Session) sendOverflowPacketWithPkt(pck *NetPacket, pkt *NetPacket) error {
	innerSize := GetOverflowNetPacketDataSize(pkt.GetReadableSize())
	var innerHeader bytes.Buffer
	WriteNetPacketHeader(&innerHeader, pkt.Opcode, innerSize)

	parts := [][]byte{
		pck.GetReadableBytes(),
		innerHeader.Bytes(),
		pkt.GetReadableBytes(),
	}
	return obj.doSendFragmentPacket(pck.Opcode, parts)
}

// doSendFragmentPacket queues a message that is too large for a single packet
// as a sequence of OPCODE_LARGE_PACKET fragments.
//
// The logical payload is the concatenation of datas. Every fragment body starts
// with the same freshly allocated uint32 serial number (fragPrefixSize bytes).
// The first fragment also carries a large-packet header with the real opcode
// and the total payload size.
//
// Fragments are filled to exactly PacketMaxBufferSize bytes. The receiver
// (doRecvFragmentPacket) treats the first shorter fragment as the last one. So
// if the payload ends exactly on a fragment boundary, an extra prefix-only
// fragment is queued to terminate the sequence.
//
// Fragments are written straight into obj.buffer; the caller must wake the send
// goroutine. Always returns nil; panics if the byte accounting goes wrong.
func (obj *Session) doSendFragmentPacket(opcode int, datas [][]byte) error {
	var totalPacketSize int
	for _, data := range datas {
		totalPacketSize += len(data)
	}

	var fragPck = NetPacket{Opcode: OPCODE_LARGE_PACKET}
	fragPck.Write(atomic.AddUint32(&obj.fragPckSN, 1))
	fragPrefixSize := fragPck.GetReadableSize()
	var header bytes.Buffer
	WriteLargeNetPacketHeader(&header, opcode, totalPacketSize)
	fragPck.AppendBytes(header.Bytes())

	// True while the most recently emitted fragment was full (or none has been
	// emitted yet), i.e. a terminating short fragment is still owed.
	var isResidualData = true
	for _, data := range datas {
		for len(data) > 0 {
			// Room left in the fragment currently being built.
			freeLen := PacketMaxBufferSize - fragPck.GetReadableSize()
			availLen := MinInt(freeLen, len(data))
			// totalPacketSize holds the bytes not yet consumed from datas.
			if availLen < freeLen && availLen < totalPacketSize {
				// The whole chunk fits with room to spare and more payload
				// follows: keep accumulating into fragPck.
				fragPck.AppendBytes(data)
			} else {
				// Fragment is full, or this is the tail of the payload: emit
				// fragPck plus this chunk as one fragment.
				obj.buffer.writePacketWithData(&fragPck, data[:availLen])
				isResidualData = availLen >= freeLen
				// Keep only the serial-number prefix for the next fragment.
				fragPck.Truncate(fragPrefixSize)
			}
			data = data[availLen:]
			totalPacketSize -= availLen
		}
	}

	if isResidualData {
		// Last emitted fragment was full: send a short one to end the sequence.
		obj.buffer.writePacket(&fragPck)
	}
	if totalPacketSize != 0 {
		panic(errors.New("SendFragmentPacket: UnhandlePacketSize != 0"))
	}
	return nil
}

// startSendCoroutine registers the connection in AllClosers and launches the
// goroutine that drains obj.buffer to obj.conn.
//
// The goroutine sleeps on sendSig, then writes every chunk returned by
// buffer.nextData. It exits when WaitBlock reports false, when a write fails,
// or when the session is shut down and the buffer is empty.
//
// On exit it shuts the session down, closes the connection, and only then
// removes the connection from AllClosers. waitStopSendCoroutine therefore
// returns after the socket is actually closed.
func (obj *Session) startSendCoroutine() {
	AllClosers.Store(obj.conn, true)
	go SafeHandler(func() {
		// Registered first so it runs last, and still runs if the cleanup
		// below panics; otherwise waitStopSendCoroutine would spin forever.
		defer func() {
			AllClosers.Delete(obj.conn)
		}()
		defer func() {
			obj.Shutdown(true)
			// After a graceful Shutdown(false) — the normal way this loop
			// ends — isShutdown is already true and Shutdown(true) is a
			// no-op, so close explicitly to avoid leaking the socket.
			// A second Close on an already-closed conn is harmless.
			obj.conn.Close()
		}()
	loop:
		for WaitBlock(obj.sendSig) {
			for {
				data := obj.buffer.nextData()
				if data == nil {
					if obj.isShutdown && !obj.buffer.hasData() {
						break loop
					}
					break
				}
				if _, err := obj.conn.Write(data); err != nil {
					log.Printf("write `%s` packet failed, %s.\n", obj, err)
					break loop
				}
			}
		}
	})()
}

// waitStopSendCoroutine polls AllClosers every millisecond and returns once
// the send goroutine has removed this session's connection from it.
func (obj *Session) waitStopSendCoroutine() {
	for {
		if _, running := AllClosers.Load(obj.conn); !running {
			return
		}
		time.Sleep(time.Millisecond)
	}
}

// checkHandleStatus logs the outcome of handling packet opcode and reports
// whether the session may continue.
//
// Success is silent; Unhandle, Warning and Error are logged and return true.
// SessionHandleKill and any unknown status are logged as fatal and return
// false.
func (obj *Session) checkHandleStatus(status int, opcode int) bool {
	switch status {
	case SessionHandleSuccess:
		return true
	case SessionHandleUnhandle:
		log.Printf("unhandle `%s` packet[%d].\n", obj, opcode)
		return true
	case SessionHandleWarning:
		log.Printf("handle `%s` packet[%d] warning.\n", obj, opcode)
		return true
	case SessionHandleError:
		log.Printf("handle `%s` packet[%d] error.\n", obj, opcode)
		return true
	}
	// SessionHandleKill and unrecognised statuses are fatal for the session.
	log.Printf("Fatal error occurred when processing opcode[%d], "+
		"the session[%s] has been removed.", opcode, obj)
	return false
}
